{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# An empty CUDA_VISIBLE_DEVICES exposes no GPU to CUDA-backed libraries imported after this point.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "# Tokenizer is taken from the published t5-small Bahasa checkpoint.\n",
    "# NOTE(review): the fine-tuned checkpoints used in this notebook are t5-tiny; presumably both\n",
    "# sizes share one vocabulary - confirm.\n",
    "tokenizer = T5Tokenizer.from_pretrained('mesolitica/t5-small-standard-bahasa-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['finetune-t5-tiny-standard-bahasa-cased/checkpoint-280000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-290000',\n",
       " 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-300000']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "# Fine-tuned checkpoint directories on local disk, in lexicographic path order.\n",
    "checkpoints = sorted(glob('finetune-t5-tiny-standard-bahasa-cased/checkpoint-*'))\n",
    "if not checkpoints:\n",
    "    raise FileNotFoundError('no checkpoint-* directory found under finetune-t5-tiny-standard-bahasa-cased')\n",
    "\n",
    "# `model` is used by every generation / upload cell below but was never assigned in the saved\n",
    "# notebook, so a fresh Restart & Run All stopped with NameError at the first generate() call.\n",
    "# NOTE(review): the cell that originally loaded the model is gone; the last checkpoint in sorted\n",
    "# order is assumed to be the intended one - confirm before publishing again.\n",
    "model = T5ForConditionalGeneration.from_pretrained(checkpoints[-1])\n",
    "\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from unidecode import unidecode\n",
    "import re\n",
    "\n",
    "# Minimal cleaning only: flatten newlines and squeeze repeated spaces.\n",
    "def cleaning(string):\n",
    "    \"\"\"Return `string` on a single line: newlines become spaces, runs of spaces collapse, ends are trimmed.\"\"\"\n",
    "    flattened = string.replace('\\n', ' ')\n",
    "    squeezed = re.sub(r' +', ' ', flattened)\n",
    "    return squeezed.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample passage (about Najib Razak) used as the context for the questions in q_wikipedia.\n",
    "p_wikipedia = \"\"\"\n",
    "Najib razak telah dipilih untuk Parlimen Malaysia pada tahun 1976,\n",
    "pada usia 23 tahun, menggantikan bapanya duduk di kerusi Pekan yang berpangkalan di Pahang.\n",
    "Dari tahun 1982 hingga 1986 beliau menjadi Menteri Besar (Ketua Menteri) Pahang,\n",
    "sebelum memasuki persekutuan Kabinet Tun Dr Mahathir Mohamad pada tahun 1986 sebagai Menteri Kebudayaan, Belia dan Sukan.\n",
    "Beliau telah berkhidmat dalam pelbagai jawatan Kabinet sepanjang baki tahun 1980-an dan 1990-an, termasuk sebagai Menteri Pertahanan dan Menteri Pelajaran.\n",
    "Beliau menjadi Timbalan Perdana Menteri pada 7 Januari 2004, berkhidmat di bawah Perdana Menteri Tun Dato' Seri Abdullah Ahmad Badawi,\n",
    "sebelum menggantikan Badawi setahun selepas Barisan Nasional mengalami kerugian besar dalam pilihan raya 2008.\n",
    "Di bawah kepimpinan beliau, Barisan Nasional memenangi pilihan raya 2013,\n",
    "walaupun buat kali pertama dalam sejarah Malaysia pembangkang memenangi majoriti undi popular.\n",
    "\"\"\"\n",
    "# Questions asked against the passage above, one generated answer per entry.\n",
    "q_wikipedia = ['bilakah najib dipilih untuk parlimen malaysia', \n",
    "               'Apakah jawatan yang pernah dipegang oleh Najib Razak?']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bilakah najib dipilih untuk parlimen malaysia : 1976\n",
      "Apakah jawatan yang pernah dipegang oleh Najib Razak? : Kabinet\n"
     ]
    }
   ],
   "source": [
    "# Every question is paired with the same cleaned passage as its context.\n",
    "context = cleaning(p_wikipedia)\n",
    "prompts = [f'konteks: {context} soalan: {question}' for question in q_wikipedia]\n",
    "features = [\n",
    "    {'input_ids': tokenizer.encode(prompt, return_tensors='pt')[0]}\n",
    "    for prompt in prompts\n",
    "]\n",
    "\n",
    "# Pad to the longest prompt so that all questions go through a single generate() call.\n",
    "batch = tokenizer.pad(features, padding='longest')\n",
    "generated = model.generate(\n",
    "    **batch,\n",
    "    max_length=256,\n",
    "    num_beams=5,\n",
    "    early_stopping=True,\n",
    ")\n",
    "decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)\n",
    "for question, answer in zip(q_wikipedia, decoded):\n",
    "    print(question, ':', answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample news passage used as the context for the questions in q_news.\n",
    "p_news = \"\"\"\n",
    "Bekas perdana menteri Najib Razak mempersoalkan tindakan polis yang menurutnya tidak serta-merta mengeluarkan kenyataan berhubung dakwaan Adun Perikatan Nasional (PN) \"merancang\" insiden rogol.\n",
    "Sedangkan, kata ahli parlimen Pekan itu, polis pantas mengeluarkan kenyataan apabila dia dilapor terlupa mengimbas MySejahtera sebelum masuk restoran.\n",
    "\"Berita Najib lupa scan MySejahtera tular, kenyataan polis terus keluar. Berita Dr Mahathir Mohamad lupa scan, kenyataan, polis serta-merta keluar.\n",
    "\"Sebab itu saya pelik kenapa pihak polis belum sempat keluar apa-apa kenyataan berhubung kes seorang gadis membuat laporan polis untuk dakwa Adun PN rancang insiden rogolnya,\" katanya di Facebook hari ini.\n",
    "Najib merujuk dakwaan seorang wanita yang mendakwa dirogol kenalan kepada Adun Gombak Setia, Hilman Idham.\n",
    "Wanita itu mendakwa ahli politik dari Bersatu berkenaan merancang insiden yang berlaku pada 5 Dis lalu.\n",
    "Menurut laporan polis pada 8 Mei, mangsa mendakwa kejadian itu berlaku di sebuah hotel di Selangor, yang pada masa itu berada di bawah perintah kawalan pergerakan bersyarat (PKPB).\n",
    "\"\"\"\n",
    "\n",
    "# Questions asked against the news passage; one of them is phrased in English.\n",
    "q_news = ['siapakah yang mempersoalkan tindakan polis', 'siapakah Adun Gombak Setia',\n",
    "         'siapakah ahli perlimen Pekan?', 'who is adun gombak setia',\n",
    "         'bilakah tarikh laporan polis tentang kejadian di sebuah hotel di Selangor']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "siapakah yang mempersoalkan tindakan polis : Bekas perdana menteri Najib Razak\n",
      "siapakah Adun Gombak Setia : Hilman Idham\n",
      "siapakah ahli perlimen Pekan? : parlimen Pekan itu\n",
      "who is adun gombak setia : Hilman Idham\n",
      "bilakah tarikh laporan polis tentang kejadian di sebuah hotel di Selangor : 8 Mei\n"
     ]
    }
   ],
   "source": [
    "# Build one 'konteks: ... soalan: ...' prompt per question over the cleaned news passage.\n",
    "passage = cleaning(p_news)\n",
    "encoded_prompts = []\n",
    "for question in q_news:\n",
    "    token_ids = tokenizer.encode(f'konteks: {passage} soalan: {question}', return_tensors='pt')[0]\n",
    "    encoded_prompts.append({'input_ids': token_ids})\n",
    "\n",
    "# Pad to a common length and answer all questions in one batch.\n",
    "model_inputs = tokenizer.pad(encoded_prompts, padding='longest')\n",
    "generated_ids = model.generate(**model_inputs, max_length=256)\n",
    "replies = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "for question, reply in zip(q_news, replies):\n",
    "    print(question, ':', reply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a705e336aea141d0824a75858374d5de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file pytorch_model.bin:   0%|          | 32.0k/133M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-extractive-qa-t5-tiny-standard-bahasa-cased\n",
      "   d01ba07..da4caf7  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-extractive-qa-t5-tiny-standard-bahasa-cased/commit/da4caf7c08931a6731ca7e9d666611d9ce6cebbd'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the model to the Hub repository 'finetune-extractive-qa-t5-tiny-standard-bahasa-cased'\n",
    "# under the 'mesolitica' organization.\n",
    "# NOTE(review): newer transformers releases deprecate `organization=` in favour of a full\n",
    "# 'org/name' repo id - confirm against the installed version.\n",
    "model.push_to_hub('finetune-extractive-qa-t5-tiny-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c19641dc78d24b3d89e1f8d1552bd601",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file spiece.model:   4%|4         | 32.0k/784k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-extractive-qa-t5-tiny-standard-bahasa-cased\n",
      "   da4caf7..0d0d252  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-extractive-qa-t5-tiny-standard-bahasa-cased/commit/0d0d252f5661def960218e61fa956d51d15cc37a'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the tokenizer files to the same Hub repository as the model, under 'mesolitica'.\n",
    "# NOTE(review): newer transformers releases deprecate `organization=` in favour of a full\n",
    "# 'org/name' repo id - confirm against the installed version.\n",
    "tokenizer.push_to_hub('finetune-extractive-qa-t5-tiny-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import collections\n",
    "import string\n",
    "\n",
    "def normalize_answer(s):\n",
    "    \"\"\"Canonicalise an answer string for scoring.\n",
    "\n",
    "    Steps, in order: lowercase, drop ASCII punctuation, blank out the English\n",
    "    articles a / an / the, then squeeze all whitespace to single spaces.\n",
    "    \"\"\"\n",
    "    punctuation = set(string.punctuation)\n",
    "    lowered = s.lower()\n",
    "    without_punc = ''.join(ch for ch in lowered if ch not in punctuation)\n",
    "    without_articles = re.sub(r'\\b(a|an|the)\\b', ' ', without_punc, flags=re.UNICODE)\n",
    "    return ' '.join(without_articles.split())\n",
    "\n",
    "\n",
    "def get_tokens(s):\n",
    "    \"\"\"Return the normalised answer split on whitespace; a falsy input yields an empty list.\"\"\"\n",
    "    return normalize_answer(s).split() if s else []\n",
    "\n",
    "\n",
    "def compute_exact(a_gold, a_pred):\n",
    "    \"\"\"Return 1 when gold and prediction are identical after normalisation, otherwise 0.\"\"\"\n",
    "    gold_norm = normalize_answer(a_gold)\n",
    "    pred_norm = normalize_answer(a_pred)\n",
    "    return 1 if gold_norm == pred_norm else 0\n",
    "\n",
    "\n",
    "def compute_f1(a_gold, a_pred):\n",
    "    \"\"\"Token-overlap F1 between a gold answer and a predicted answer.\n",
    "\n",
    "    Returns an int (0 or 1) when either side has no tokens or nothing overlaps,\n",
    "    otherwise the harmonic mean of token precision and recall as a float.\n",
    "    \"\"\"\n",
    "    gold_toks, pred_toks = get_tokens(a_gold), get_tokens(a_pred)\n",
    "    # An empty side means no-answer: score 1 only when both sides are empty.\n",
    "    if not gold_toks or not pred_toks:\n",
    "        return int(gold_toks == pred_toks)\n",
    "    # Multiset intersection counts each shared token at most as often as it occurs on both sides.\n",
    "    overlap = collections.Counter(gold_toks) & collections.Counter(pred_toks)\n",
    "    num_same = sum(overlap.values())\n",
    "    if not num_same:\n",
    "        return 0\n",
    "    precision = num_same / len(pred_toks)\n",
    "    recall = num_same / len(gold_toks)\n",
    "    return 2 * precision * recall / (precision + recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Read the evaluation file explicitly as UTF-8 so that parsing does not depend on the\n",
    "# platform's default text encoding.\n",
    "with open('ms-dev-2.0.json', encoding='utf-8') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:05<00:00,  6.45it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "srcs, answers = [], []\n",
    "# Iterate the articles directly; tqdm reads the total from the list length.\n",
    "for article in tqdm(data['data']):\n",
    "    for p in article['paragraphs']:\n",
    "        text = p['context']\n",
    "        # Contexts longer than 500 whitespace-separated words are left out of the evaluation.\n",
    "        if len(text.split()) > 500:\n",
    "            continue\n",
    "\n",
    "        for q in p['qas']:\n",
    "            qs = q['question']\n",
    "            # `or []` also covers a missing / null `answers` field, which previously raised.\n",
    "            gold = q.get('answers') or []\n",
    "            if q.get('is_impossible', False) or not gold:\n",
    "                a = 'tiada jawapan'\n",
    "            else:\n",
    "                # Only the first listed gold answer is used as the reference; an empty\n",
    "                # answer text falls back to the no-answer marker as well.\n",
    "                a = gold[0]['text'] or 'tiada jawapan'\n",
    "\n",
    "            src = f'konteks: {text} soalan: {qs}'\n",
    "            # return_tensors='pt' keeps the leading batch dimension, so each entry can be\n",
    "            # unpacked straight into model.generate().\n",
    "            srcs.append({'input_ids': tokenizer.encode(src, return_tensors='pt')})\n",
    "            answers.append(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 11849/11849 [08:17<00:00, 23.83it/s]\n"
     ]
    }
   ],
   "source": [
    "exact, f1 = [], []\n",
    "\n",
    "# One generate() call per example; scores are appended in the order of `srcs`.\n",
    "for encoded, gold in tqdm(zip(srcs, answers), total=len(srcs)):\n",
    "    generated = model.generate(**encoded, max_length=256)\n",
    "    prediction = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]\n",
    "    f1.append(compute_f1(gold, prediction))\n",
    "    exact.append(compute_exact(gold, prediction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.4269558612541143, 0.5113033923646412)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Displayed as a tuple: (mean exact-match, mean F1) over all evaluated examples.\n",
    "np.mean(exact), np.mean(f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Read the evaluation file explicitly as UTF-8 so that parsing does not depend on the\n",
    "# platform's default text encoding.\n",
    "with open('dev-v2.0.json', encoding='utf-8') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:05<00:00,  6.43it/s]\n"
     ]
    }
   ],
   "source": [
    "srcs, answers = [], []\n",
    "# Iterate the articles directly; tqdm reads the total from the list length.\n",
    "for article in tqdm(data['data']):\n",
    "    for p in article['paragraphs']:\n",
    "        text = p['context']\n",
    "        # Contexts longer than 500 whitespace-separated words are left out of the evaluation.\n",
    "        if len(text.split()) > 500:\n",
    "            continue\n",
    "\n",
    "        for q in p['qas']:\n",
    "            qs = q['question']\n",
    "            # `or []` also covers a missing / null `answers` field, which previously raised.\n",
    "            gold = q.get('answers') or []\n",
    "            if q.get('is_impossible', False) or not gold:\n",
    "                a = 'tiada jawapan'\n",
    "            else:\n",
    "                # Only the first listed gold answer is used as the reference; an empty\n",
    "                # answer text falls back to the no-answer marker as well.\n",
    "                a = gold[0]['text'] or 'tiada jawapan'\n",
    "\n",
    "            src = f'konteks: {text} soalan: {qs}'\n",
    "            # return_tensors='pt' keeps the leading batch dimension, so each entry can be\n",
    "            # unpacked straight into model.generate().\n",
    "            srcs.append({'input_ids': tokenizer.encode(src, return_tensors='pt')})\n",
    "            answers.append(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 11840/11840 [08:39<00:00, 22.78it/s]\n"
     ]
    }
   ],
   "source": [
    "exact, f1 = [], []\n",
    "\n",
    "# Score each example individually against its stored reference answer.\n",
    "for idx, model_inputs in enumerate(tqdm(srcs)):\n",
    "    reference = answers[idx]\n",
    "    generated_ids = model.generate(**model_inputs, max_length=256)\n",
    "    predicted = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "    f1.append(compute_f1(reference, predicted))\n",
    "    exact.append(compute_exact(reference, predicted))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.46925675675675677, 0.5410633845304322)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Displayed as a tuple: (mean exact-match, mean F1) over all evaluated examples.\n",
    "np.mean(exact), np.mean(f1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
